<html><!-- Created using the cpp_pretty_printer from the dlib C++ library.  See http://dlib.net for updates. --><head><title>dlib C++ Library - svm_multiclass_linear_trainer.h</title></head><body bgcolor='white'><pre>
<font color='#009900'>// Copyright (C) 2011  Davis E. King (davis@dlib.net)
</font><font color='#009900'>// License: Boost Software License   See LICENSE.txt for the full license.
</font><font color='#0000FF'>#ifndef</font> DLIB_SVm_MULTICLASS_LINEAR_TRAINER_H__ 
<font color='#0000FF'>#define</font> DLIB_SVm_MULTICLASS_LINEAR_TRAINER_H__

<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='svm_multiclass_linear_trainer_abstract.h.html'>svm_multiclass_linear_trainer_abstract.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='structural_svm_problem_threaded.h.html'>structural_svm_problem_threaded.h</a>"
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>vector<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../optimization/optimization_oca.h.html'>../optimization/optimization_oca.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../matrix.h.html'>../matrix.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='sparse_vector.h.html'>sparse_vector.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='function.h.html'>function.h</a>"
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>algorithm<font color='#5555FF'>&gt;</font>

<font color='#0000FF'>namespace</font> dlib
<b>{</b>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
    <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
        <font color='#0000FF'>typename</font> matrix_type,
        <font color='#0000FF'>typename</font> sample_type,
        <font color='#0000FF'>typename</font> label_type
        <font color='#5555FF'>&gt;</font>
    <font color='#0000FF'>class</font> <b><a name='multiclass_svm_problem'></a>multiclass_svm_problem</b> : <font color='#0000FF'>public</font> structural_svm_problem_threaded<font color='#5555FF'>&lt;</font>matrix_type,
                                                                 std::vector<font color='#5555FF'>&lt;</font>std::pair<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font>,<font color='#0000FF'>typename</font> matrix_type::type<font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font> 
    <b>{</b>
        <font color='#009900'>/*!
            WHAT THIS OBJECT REPRESENTS
                This object defines the optimization problem for the multiclass SVM trainer
                object at the bottom of this file.  

                The joint feature vectors used by this object, the PSI(x,y) vectors, are
                defined as follows:
                    PSI(x,0) = [x,0,0,0,0, ...,0]
                    PSI(x,1) = [0,x,0,0,0, ...,0]
                    PSI(x,2) = [0,0,x,0,0, ...,0]
                That is, if there are N labels then the joint feature vector has a
                dimension that is N times the dimension of a single x sample.  Also,
                note that we append a -1 value onto each x to account for the bias term.
                Consequently each label owns a block of dims = dims_+1 entries: the
                sample's features followed by the bias weight in the block's last slot.

                The samples, labels and distinct labels given to the constructor are
                held by reference (see the data members), so the caller must keep them
                alive for as long as this object is used.
        !*/</font>

    <font color='#0000FF'>public</font>:
        <font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> matrix_type::type scalar_type;
        <font color='#0000FF'>typedef</font> std::vector<font color='#5555FF'>&lt;</font>std::pair<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font>,scalar_type<font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font> feature_vector_type;

        <b><a name='multiclass_svm_problem'></a>multiclass_svm_problem</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>sample_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> samples_,
            <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>label_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> labels_,
            <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>label_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> distinct_labels_,
            <font color='#0000FF'>const</font> <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> dims_,
            <font color='#0000FF'>const</font> <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> num_threads
        <font face='Lucida Console'>)</font> :
            structural_svm_problem_threaded<font color='#5555FF'>&lt;</font>matrix_type, std::vector<font color='#5555FF'>&lt;</font>std::pair<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font>,<font color='#0000FF'>typename</font> matrix_type::type<font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>num_threads<font face='Lucida Console'>)</font>,
            samples<font face='Lucida Console'>(</font>samples_<font face='Lucida Console'>)</font>,
            labels<font face='Lucida Console'>(</font>labels_<font face='Lucida Console'>)</font>,
            distinct_labels<font face='Lucida Console'>(</font>distinct_labels_<font face='Lucida Console'>)</font>,
            dims<font face='Lucida Console'>(</font>dims_<font color='#5555FF'>+</font><font color='#979000'>1</font><font face='Lucida Console'>)</font> <font color='#009900'>// +1 for the bias
</font>        <b>{</b><b>}</b>

        <font color='#009900'>// Total size of the weight vector: one block of dims entries per distinct label.</font>
        <font color='#0000FF'>virtual</font> <font color='#0000FF'><u>long</u></font> <b><a name='get_num_dimensions'></a>get_num_dimensions</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
        <b>{</b>
            <font color='#0000FF'>return</font> dims<font color='#5555FF'>*</font>distinct_labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        <b>}</b>

        <font color='#009900'>// Number of training samples.</font>
        <font color='#0000FF'>virtual</font> <font color='#0000FF'><u>long</u></font> <b><a name='get_num_samples'></a>get_num_samples</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> 
        <b>{</b>
            <font color='#0000FF'>return</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font><font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        <b>}</b>

        <font color='#009900'>// Builds PSI(samples[idx], labels[idx]): the sample plus a trailing -1 bias feature,</font>
        <font color='#009900'>// shifted into the block that belongs to labels[idx].</font>
        <font color='#009900'>// NOTE(review): if labels[idx] is not found in distinct_labels, label_idx silently stays</font>
        <font color='#009900'>// 0 and label 0's block is used; presumably the caller builds distinct_labels from</font>
        <font color='#009900'>// labels so this cannot happen -- confirm.</font>
        <font color='#0000FF'>virtual</font> <font color='#0000FF'><u>void</u></font> <b><a name='get_truth_joint_feature_vector'></a>get_truth_joint_feature_vector</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'><u>long</u></font> idx,
            feature_vector_type<font color='#5555FF'>&amp;</font> psi
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> 
        <b>{</b>
            <font color='#BB00BB'>assign</font><font face='Lucida Console'>(</font>psi, samples[idx]<font face='Lucida Console'>)</font>;
            <font color='#009900'>// Add a constant -1 to account for the bias term.
</font>            psi.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>make_pair</font><font face='Lucida Console'>(</font>dims<font color='#5555FF'>-</font><font color='#979000'>1</font>,<font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font>scalar_type<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font><font color='#5555FF'>-</font><font color='#979000'>1</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

            <font color='#009900'>// Find which distinct label goes with this psi.
</font>            <font color='#0000FF'><u>long</u></font> label_idx <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
            <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> distinct_labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
            <b>{</b>
                <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>distinct_labels[i] <font color='#5555FF'>=</font><font color='#5555FF'>=</font> labels[idx]<font face='Lucida Console'>)</font>
                <b>{</b>
                    label_idx <font color='#5555FF'>=</font> i;
                    <font color='#0000FF'>break</font>;
                <b>}</b>
            <b>}</b>

            <font color='#009900'>// Shift the whole vector, bias included, into label_idx's block.  The bias must</font>
            <font color='#009900'>// already be appended at this point so that it lands in the same block.</font>
            <font color='#BB00BB'>offset_feature_vector</font><font face='Lucida Console'>(</font>psi, dims<font color='#5555FF'>*</font>label_idx<font face='Lucida Console'>)</font>;
        <b>}</b>

        <font color='#009900'>// Loss-augmented inference for sample idx: picks the label y maximizing</font>
        <font color='#009900'>// LOSS(idx,y) + F(x,y), where LOSS is 1 for a wrong label and 0 for the true one.</font>
        <font color='#009900'>// Writes PSI(x,y) for that label to psi and its loss to loss.</font>
        <font color='#0000FF'>virtual</font> <font color='#0000FF'><u>void</u></font> <b><a name='separation_oracle'></a>separation_oracle</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> idx,
            <font color='#0000FF'>const</font> matrix_type<font color='#5555FF'>&amp;</font> current_solution,
            scalar_type<font color='#5555FF'>&amp;</font> loss,
            feature_vector_type<font color='#5555FF'>&amp;</font> psi
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> 
        <b>{</b>
            <font color='#009900'>// NOTE(review): std::numeric_limits is declared in &lt;limits&gt;, which this file does not</font>
            <font color='#009900'>// include directly; it is presumably reached transitively through the dlib headers.</font>
            scalar_type best_val <font color='#5555FF'>=</font> <font color='#5555FF'>-</font>std::numeric_limits<font color='#5555FF'>&lt;</font>scalar_type<font color='#5555FF'>&gt;</font>::<font color='#BB00BB'>infinity</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
            <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> best_idx <font color='#5555FF'>=</font> <font color='#979000'>0</font>;

            <font color='#009900'>// Figure out which label is the best.  That is, what label maximizes
</font>            <font color='#009900'>// LOSS(idx,y) + F(x,y).  Note that y in this case is given by distinct_labels[i].
</font>            <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> distinct_labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
            <b>{</b>
                <font color='#009900'>// Label i's weights are current_solution(i*dims) .. current_solution((i+1)*dims-1);</font>
                <font color='#009900'>// the first dims-1 multiply the sample features and the last one is the bias.</font>
                <font color='#009900'>// Compute the F(x,y) part:
</font>                <font color='#009900'>// perform: temp = dot(relevant part of current solution, samples[idx]) - current_bias
</font>                scalar_type temp <font color='#5555FF'>=</font> <font color='#BB00BB'>dot</font><font face='Lucida Console'>(</font><font color='#BB00BB'>mat</font><font face='Lucida Console'>(</font><font color='#5555FF'>&amp;</font><font color='#BB00BB'>current_solution</font><font face='Lucida Console'>(</font>i<font color='#5555FF'>*</font>dims<font face='Lucida Console'>)</font>,dims<font color='#5555FF'>-</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>, samples[idx]<font face='Lucida Console'>)</font> <font color='#5555FF'>-</font> <font color='#BB00BB'>current_solution</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>i<font color='#5555FF'>+</font><font color='#979000'>1</font><font face='Lucida Console'>)</font><font color='#5555FF'>*</font>dims<font color='#5555FF'>-</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;

                <font color='#009900'>// Add the LOSS(idx,y) part:
</font>                <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>labels[idx] <font color='#5555FF'>!</font><font color='#5555FF'>=</font> distinct_labels[i]<font face='Lucida Console'>)</font>
                    temp <font color='#5555FF'>+</font><font color='#5555FF'>=</font> <font color='#979000'>1</font>;

                <font color='#009900'>// Now temp == LOSS(idx,y) + F(x,y).  Check if it is the biggest we have seen.
</font>                <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>temp <font color='#5555FF'>&gt;</font> best_val<font face='Lucida Console'>)</font>
                <b>{</b>
                    best_val <font color='#5555FF'>=</font> temp;
                    best_idx <font color='#5555FF'>=</font> i;
                <b>}</b>
            <b>}</b>

            <font color='#BB00BB'>assign</font><font face='Lucida Console'>(</font>psi, samples[idx]<font face='Lucida Console'>)</font>;
            <font color='#009900'>// add a constant -1 to account for the bias term
</font>            psi.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>make_pair</font><font face='Lucida Console'>(</font>dims<font color='#5555FF'>-</font><font color='#979000'>1</font>,<font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font>scalar_type<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font><font color='#5555FF'>-</font><font color='#979000'>1</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

            <font color='#BB00BB'>offset_feature_vector</font><font face='Lucida Console'>(</font>psi, dims<font color='#5555FF'>*</font>best_idx<font face='Lucida Console'>)</font>;

            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>distinct_labels[best_idx] <font color='#5555FF'>=</font><font color='#5555FF'>=</font> labels[idx]<font face='Lucida Console'>)</font>
                loss <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
            <font color='#0000FF'>else</font>
                loss <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
        <b>}</b>

    <font color='#0000FF'>private</font>:

        <font color='#009900'>// Adds val to every feature index in sample, moving the vector into another label's</font>
        <font color='#009900'>// block of the joint feature space.  A zero offset leaves sample untouched.</font>
        <font color='#0000FF'><u>void</u></font> <b><a name='offset_feature_vector'></a>offset_feature_vector</b> <font face='Lucida Console'>(</font>
            feature_vector_type<font color='#5555FF'>&amp;</font> sample,
            <font color='#0000FF'>const</font> <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> val
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
        <b>{</b>
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>val <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
            <b>{</b>
                <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>typename</font> feature_vector_type::iterator i <font color='#5555FF'>=</font> sample.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; i <font color='#5555FF'>!</font><font color='#5555FF'>=</font> sample.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
                <b>{</b>
                    i<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>first <font color='#5555FF'>+</font><font color='#5555FF'>=</font> val;
                <b>}</b>
            <b>}</b>
        <b>}</b>


        <font color='#009900'>// Borrowed references: the caller owns these vectors and must keep them alive.</font>
        <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>sample_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> samples;
        <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>label_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> labels;
        <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>label_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> distinct_labels;
        <font color='#009900'>// Per-label block size: the constructor's dims_ plus one slot for the bias.</font>
        <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> dims;
    <b>}</b>;


<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
    <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
        <font color='#0000FF'>typename</font> K,
        <font color='#0000FF'>typename</font> label_type_ <font color='#5555FF'>=</font> <font color='#0000FF'>typename</font> K::scalar_type 
        <font color='#5555FF'>&gt;</font>
    <font color='#0000FF'>class</font> <b><a name='svm_multiclass_linear_trainer'></a>svm_multiclass_linear_trainer</b>
    <b>{</b>
    <font color='#0000FF'>public</font>:
        <font color='#0000FF'>typedef</font> label_type_ label_type;
        <font color='#0000FF'>typedef</font> K kernel_type;
        <font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> kernel_type::scalar_type scalar_type;
        <font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> kernel_type::sample_type sample_type;
        <font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> kernel_type::mem_manager_type mem_manager_type;

        <font color='#0000FF'>typedef</font> multiclass_linear_decision_function<font color='#5555FF'>&lt;</font>kernel_type, label_type<font color='#5555FF'>&gt;</font> trained_function_type;


        <font color='#009900'>// You are getting a compiler error on this line because you supplied a non-linear kernel
</font>        <font color='#009900'>// to the svm_multiclass_linear_trainer object.  You have to use one of the linear kernels with this
</font>        <font color='#009900'>// trainer.
</font>        <b><a name='COMPILE_TIME_ASSERT'></a>COMPILE_TIME_ASSERT</b><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>is_same_type<font color='#5555FF'>&lt;</font>K, linear_kernel<font color='#5555FF'>&lt;</font>sample_type<font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font>::value <font color='#5555FF'>|</font><font color='#5555FF'>|</font>
                             is_same_type<font color='#5555FF'>&lt;</font>K, sparse_linear_kernel<font color='#5555FF'>&lt;</font>sample_type<font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font>::value <font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

        <font color='#009900'>// Defaults: 4 threads, C = 1, eps = 0.001, verbose output off and</font>
        <font color='#009900'>// learn_nonnegative_weights off.  prior is not listed here, so it is default constructed.</font>
        <b><a name='svm_multiclass_linear_trainer'></a>svm_multiclass_linear_trainer</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font> :
            num_threads<font face='Lucida Console'>(</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>,
            C<font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>,
            eps<font face='Lucida Console'>(</font><font color='#979000'>0.001</font><font face='Lucida Console'>)</font>,
            verbose<font face='Lucida Console'>(</font><font color='#979000'>false</font><font face='Lucida Console'>)</font>,
            learn_nonnegative_weights<font face='Lucida Console'>(</font><font color='#979000'>false</font><font face='Lucida Console'>)</font>
        <b>{</b>
        <b>}</b>

        <font color='#009900'>// Stores the thread count that get_num_threads() reports back.</font>
        <font color='#0000FF'><u>void</u></font> <b><a name='set_num_threads'></a>set_num_threads</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> num
        <font face='Lucida Console'>)</font>
        <b>{</b> <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>num_threads <font color='#5555FF'>=</font> num; <b>}</b>

        <font color='#009900'>// Returns the thread count last given to set_num_threads() (4 unless changed).</font>
        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_num_threads'></a>get_num_threads</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> <b>{</b> <font color='#0000FF'>return</font> <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>num_threads; <b>}</b>

        <font color='#009900'>// Sets the value returned by get_epsilon().  Requires eps_ &gt; 0 (checked with DLIB_ASSERT).</font>
        <font color='#0000FF'><u>void</u></font> <b><a name='set_epsilon'></a>set_epsilon</b> <font face='Lucida Console'>(</font>
            scalar_type eps_
        <font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#BB00BB'>DLIB_ASSERT</font><font face='Lucida Console'>(</font>eps_ <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font>,
                "<font color='#CC0000'>\t void svm_multiclass_linear_trainer::set_epsilon()</font>"
                <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t eps_ must be greater than 0</font>"
                <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t eps_: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> eps_ 
                <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t this: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#0000FF'>this</font><font face='Lucida Console'>)</font>;
            <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>eps <font color='#5555FF'>=</font> eps_;
        <b>}</b>

        <font color='#009900'>// Returns the value last given to set_epsilon() (0.001 unless changed).</font>
        <font color='#0000FF'>const</font> scalar_type <b><a name='get_epsilon'></a>get_epsilon</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
        <b>{</b>
            <font color='#0000FF'>return</font> <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>eps;
        <b>}</b>

        <font color='#009900'>// Switches the verbose flag on.</font>
        <font color='#0000FF'><u>void</u></font> <b><a name='be_verbose'></a>be_verbose</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>verbose <font color='#5555FF'>=</font> <font color='#979000'>true</font>; <b>}</b>

        <font color='#009900'>// Switches the verbose flag off (the default).</font>
        <font color='#0000FF'><u>void</u></font> <b><a name='be_quiet'></a>be_quiet</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>verbose <font color='#5555FF'>=</font> <font color='#979000'>false</font>; <b>}</b>

        <font color='#009900'>// Replaces the trainer's oca solver object with a copy of item.</font>
        <font color='#0000FF'><u>void</u></font> <b><a name='set_oca'></a>set_oca</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'>const</font> oca<font color='#5555FF'>&amp;</font> item
        <font face='Lucida Console'>)</font>
        <b>{</b> <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>solver <font color='#5555FF'>=</font> item; <b>}</b>

        <font color='#009900'>// Returns a copy of the trainer's oca solver object.</font>
        <font color='#0000FF'>const</font> oca <b><a name='get_oca'></a>get_oca</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> <b>{</b> <font color='#0000FF'>return</font> <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>solver; <b>}</b>

        <font color='#009900'>// No kernel object is stored: a freshly value-initialized kernel_type is returned every time.</font>
        <font color='#0000FF'>const</font> kernel_type <b><a name='get_kernel'></a>get_kernel</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> <b>{</b> <font color='#0000FF'>return</font> <font color='#BB00BB'>kernel_type</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <b>}</b>

        <font color='#009900'>// Reports the flag set by set_learns_nonnegative_weights() (false unless changed).</font>
        <font color='#0000FF'><u>bool</u></font> <b><a name='learns_nonnegative_weights'></a>learns_nonnegative_weights</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
        <b>{</b>
            <font color='#0000FF'>return</font> <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>learn_nonnegative_weights;
        <b>}</b>
       
        <font color='#009900'>// Sets the learn_nonnegative_weights flag.  Enabling it also resets prior to a</font>
        <font color='#009900'>// default-constructed trained_function_type (set_prior() conversely clears the flag).</font>
        <font color='#0000FF'><u>void</u></font> <b><a name='set_learns_nonnegative_weights'></a>set_learns_nonnegative_weights</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'><u>bool</u></font> value
        <font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>learn_nonnegative_weights <font color='#5555FF'>=</font> value;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>value<font face='Lucida Console'>)</font>
            <b>{</b>
                <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>prior <font color='#5555FF'>=</font> <font color='#BB00BB'>trained_function_type</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
            <b>}</b>
        <b>}</b>

        <font color='#009900'>// Sets the value returned by get_c().  Requires C_ &gt; 0 (checked with DLIB_ASSERT).</font>
        <font color='#0000FF'><u>void</u></font> <b><a name='set_c'></a>set_c</b> <font face='Lucida Console'>(</font>
            scalar_type C_
        <font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#BB00BB'>DLIB_ASSERT</font><font face='Lucida Console'>(</font>C_ <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font>,
                "<font color='#CC0000'>\t void svm_multiclass_linear_trainer::set_c()</font>"
                <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t C must be greater than 0</font>"
                <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t C_:   </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> C_ 
                <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t this: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#0000FF'>this</font><font face='Lucida Console'>)</font>;
            <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>C <font color='#5555FF'>=</font> C_;
        <b>}</b>

        <font color='#009900'>// Returns the value last given to set_c() (1 unless changed).</font>
        <font color='#0000FF'>const</font> scalar_type <b><a name='get_c'></a>get_c</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> <b>{</b> <font color='#0000FF'>return</font> <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>C; <b>}</b>

        <font color='#009900'>// Installs a copy of prior_ as the prior and clears learn_nonnegative_weights.</font>
        <font color='#009900'>// The prior is assigned first, so the flag is left untouched if that copy throws.</font>
        <font color='#0000FF'><u>void</u></font> <b><a name='set_prior'></a>set_prior</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'>const</font> trained_function_type<font color='#5555FF'>&amp;</font> prior_
        <font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>prior <font color='#5555FF'>=</font> prior_;
            <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>learn_nonnegative_weights <font color='#5555FF'>=</font> <font color='#979000'>false</font>;
        <b>}</b>

        <font color='#009900'>// True when the stored prior carries at least one label.</font>
        <font color='#0000FF'><u>bool</u></font> <b><a name='has_prior'></a>has_prior</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
        <b>{</b>
            <font color='#0000FF'>return</font> <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>prior.labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font>;
        <b>}</b>

        <font color='#009900'>// Convenience overload: delegates to the three-argument train() and discards</font>
        <font color='#009900'>// whatever value that overload writes to its svm_objective argument.</font>
        trained_function_type <b><a name='train'></a>train</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>sample_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> all_samples,
            <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>label_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> all_labels
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
        <b>{</b>
            scalar_type ignored_objective <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
            <font color='#0000FF'>return</font> <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font><font color='#BB00BB'>train</font><font face='Lucida Console'>(</font>all_samples, all_labels, ignored_objective<font face='Lucida Console'>)</font>;
        <b>}</b>

        trained_function_type <b><a name='train'></a>train</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>sample_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> all_samples,
            <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>label_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> all_labels,
            scalar_type<font color='#5555FF'>&amp;</font> svm_objective
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
        <b>{</b>
            <font color='#009900'>// make sure requires clause is not broken
</font>            <font color='#BB00BB'>DLIB_ASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>is_learning_problem</font><font face='Lucida Console'>(</font>all_samples,all_labels<font face='Lucida Console'>)</font>,
                "<font color='#CC0000'>\t trained_function_type svm_multiclass_linear_trainer::train(all_samples,all_labels)</font>"
                <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t invalid inputs were given to this function</font>"
                <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t all_samples.size():     </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> all_samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> 
                <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t all_labels.size():      </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> all_labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> 
                <font face='Lucida Console'>)</font>;

            trained_function_type df;
            df.labels <font color='#5555FF'>=</font> <font color='#BB00BB'>select_all_distinct_labels</font><font face='Lucida Console'>(</font>all_labels<font face='Lucida Console'>)</font>;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>has_prior</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
            <b>{</b>
                df.labels.<font color='#BB00BB'>insert</font><font face='Lucida Console'>(</font>df.labels.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, prior.labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, prior.labels.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
                df.labels <font color='#5555FF'>=</font> <font color='#BB00BB'>select_all_distinct_labels</font><font face='Lucida Console'>(</font>df.labels<font face='Lucida Console'>)</font>;
            <b>}</b>
            <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> input_sample_dimensionality <font color='#5555FF'>=</font> <font color='#BB00BB'>max_index_plus_one</font><font face='Lucida Console'>(</font>all_samples<font face='Lucida Console'>)</font>;
            <font color='#009900'>// If the samples are sparse then the right thing to do is to take the max
</font>            <font color='#009900'>// dimensionality between the prior and the new samples.  But if the samples
</font>            <font color='#009900'>// are dense vectors then they definitely all have to have exactly the same
</font>            <font color='#009900'>// dimensionality.
</font>            <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> dims <font color='#5555FF'>=</font> std::<font color='#BB00BB'>max</font><font face='Lucida Console'>(</font>df.weights.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,input_sample_dimensionality<font face='Lucida Console'>)</font>;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>is_matrix<font color='#5555FF'>&lt;</font>sample_type<font color='#5555FF'>&gt;</font>::value <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> <font color='#BB00BB'>has_prior</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
            <b>{</b>
                <font color='#BB00BB'>DLIB_ASSERT</font><font face='Lucida Console'>(</font>input_sample_dimensionality <font color='#5555FF'>=</font><font color='#5555FF'>=</font> prior.weights.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, 
                    "<font color='#CC0000'>\t trained_function_type svm_multiclass_linear_trainer::train(all_samples,all_labels)</font>"
                    <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t The training samples given to this function are not the same kind of training </font>"
                    <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t samples used to create the prior.</font>"
                    <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t input_sample_dimensionality: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> input_sample_dimensionality 
                    <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t prior.weights.nc():          </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> prior.weights.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> 
                <font face='Lucida Console'>)</font>;
            <b>}</b>

            <font color='#0000FF'>typedef</font> matrix<font color='#5555FF'>&lt;</font>scalar_type,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>&gt;</font> w_type;
            w_type weights;
            multiclass_svm_problem<font color='#5555FF'>&lt;</font>w_type, sample_type, label_type<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>problem</font><font face='Lucida Console'>(</font>all_samples, all_labels, df.labels, dims, num_threads<font face='Lucida Console'>)</font>;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font>
                problem.<font color='#BB00BB'>be_verbose</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

            problem.<font color='#BB00BB'>set_max_cache_size</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>;
            problem.<font color='#BB00BB'>set_c</font><font face='Lucida Console'>(</font>C<font face='Lucida Console'>)</font>;
            problem.<font color='#BB00BB'>set_epsilon</font><font face='Lucida Console'>(</font>eps<font face='Lucida Console'>)</font>;

            <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> num_nonnegative <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>learn_nonnegative_weights<font face='Lucida Console'>)</font>
            <b>{</b>
                num_nonnegative <font color='#5555FF'>=</font> problem.<font color='#BB00BB'>get_num_dimensions</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
            <b>}</b>

            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font><font color='#BB00BB'>has_prior</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
            <b>{</b>
                svm_objective <font color='#5555FF'>=</font> <font color='#BB00BB'>solver</font><font face='Lucida Console'>(</font>problem, weights, num_nonnegative<font face='Lucida Console'>)</font>;
            <b>}</b>
            <font color='#0000FF'>else</font>
            <b>{</b>
                matrix<font color='#5555FF'>&lt;</font>scalar_type<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>temp</font><font face='Lucida Console'>(</font>df.labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,dims<font face='Lucida Console'>)</font>;
                w_type <font color='#BB00BB'>b</font><font face='Lucida Console'>(</font>df.labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
                temp <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
                b <font color='#5555FF'>=</font> <font color='#979000'>0</font>;

                <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> pad_size <font color='#5555FF'>=</font> dims<font color='#5555FF'>-</font>prior.weights.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
                <font color='#009900'>// Copy the prior into the temp and b matrices.  We have to do this row
</font>                <font color='#009900'>// by row copy because the new training data might have new labels we
</font>                <font color='#009900'>// haven't seen before and therefore the sizes of these matrices could be
</font>                <font color='#009900'>// different.
</font>                <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> prior.labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
                <b>{</b>
                    <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> r <font color='#5555FF'>=</font> std::<font color='#BB00BB'>find</font><font face='Lucida Console'>(</font>df.labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, df.labels.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, prior.labels[i]<font face='Lucida Console'>)</font><font color='#5555FF'>-</font>df.labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
                    <font color='#BB00BB'>set_rowm</font><font face='Lucida Console'>(</font>temp,r<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#BB00BB'>join_rows</font><font face='Lucida Console'>(</font><font color='#BB00BB'>rowm</font><font face='Lucida Console'>(</font>prior.weights,i<font face='Lucida Console'>)</font>, zeros_matrix<font color='#5555FF'>&lt;</font>scalar_type<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font><font color='#979000'>1</font>,pad_size<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
                    <font color='#BB00BB'>b</font><font face='Lucida Console'>(</font>r<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> prior.<font color='#BB00BB'>b</font><font face='Lucida Console'>(</font>i<font face='Lucida Console'>)</font>;
                <b>}</b>

                <font color='#0000FF'>const</font> w_type prior_vect <font color='#5555FF'>=</font> <font color='#BB00BB'>reshape_to_column_vector</font><font face='Lucida Console'>(</font><font color='#BB00BB'>join_rows</font><font face='Lucida Console'>(</font>temp,b<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
                svm_objective <font color='#5555FF'>=</font> <font color='#BB00BB'>solver</font><font face='Lucida Console'>(</font>problem, weights, prior_vect<font face='Lucida Console'>)</font>;
            <b>}</b>


            df.weights <font color='#5555FF'>=</font> <font color='#BB00BB'>colm</font><font face='Lucida Console'>(</font><font color='#BB00BB'>reshape</font><font face='Lucida Console'>(</font>weights, df.labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, dims<font color='#5555FF'>+</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>, <font color='#BB00BB'>range</font><font face='Lucida Console'>(</font><font color='#979000'>0</font>,dims<font color='#5555FF'>-</font><font color='#979000'>1</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
            df.b       <font color='#5555FF'>=</font> <font color='#BB00BB'>colm</font><font face='Lucida Console'>(</font><font color='#BB00BB'>reshape</font><font face='Lucida Console'>(</font>weights, df.labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, dims<font color='#5555FF'>+</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>, dims<font face='Lucida Console'>)</font>;
            <font color='#0000FF'>return</font> df;
        <b>}</b>

    <font color='#0000FF'>private</font>:

        // Thread count handed to the multiclass_svm_problem constructor in train().
        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> num_threads;
        // Value passed to problem.set_c() in train() (presumably the usual SVM
        // regularization trade-off parameter; confirm against set_c() docs).
        scalar_type C;
        // Value passed to problem.set_epsilon() in train() (presumably the
        // optimizer stopping tolerance; confirm against set_epsilon() docs).
        scalar_type eps;
        // When true, train() calls problem.be_verbose() before solving.
        <font color='#0000FF'><u>bool</u></font> verbose;
        // OCA optimizer; train() invokes it as solver(problem, weights, ...) and
        // stores the returned value in svm_objective.
        oca solver;
        // When true, train() passes problem.get_num_dimensions() as the solver's
        // num_nonnegative argument.  NOTE(review): that argument is only supplied
        // on the no-prior path, so this flag has no effect when has_prior() is true.
        <font color='#0000FF'><u>bool</u></font> learn_nonnegative_weights;

        // Optional prior model.  When has_prior() is true, train() merges its
        // labels into the output labels and hands its weights and biases
        // (zero-padded to the new dimensionality) to the solver as a prior vector.
        trained_function_type prior;
    <b>}</b>;

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<b>}</b>


<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_SVm_MULTICLASS_LINEAR_TRAINER_H__
</font>

</pre></body></html>